// Copyright (c) 2018 NVIDIA Corporation
// Author: Bryce Adelstein Lelbach <brycelelbach@gmail.com>
//
// Distributed under the Boost Software License v1.0 (boost.org/LICENSE_1_0.txt)

#pragma once

#include <thrust/detail/config.h>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header
#include <thrust/detail/allocator/allocator_traits.h>
#include <thrust/detail/memory_algorithms.h>
#include <thrust/detail/memory_wrapper.h>
#include <thrust/detail/raw_pointer_cast.h>
#include <thrust/detail/type_deduction.h>

#include <cuda/std/type_traits>
#include <cuda/std/utility>

THRUST_NAMESPACE_BEGIN

// allocate_unique and allocator_delete, modeled on WG21 proposal P0316R0 (wg21.link/p0316r0).

///////////////////////////////////////////////////////////////////////////////

template <typename T, typename Allocator, bool Uninitialized = false>
struct allocator_delete final
{
  using allocator_type =
    typename std::remove_cv<typename std::remove_reference<Allocator>::type>::type::template rebind<T>::other;
  using pointer = typename detail::allocator_traits<allocator_type>::pointer;

  template <typename UAllocator>
  allocator_delete(UAllocator&& other) noexcept
      : alloc_(THRUST_FWD(other))
  {}

  template <typename U, typename UAllocator>
  allocator_delete(allocator_delete<U, UAllocator> const& other) noexcept
      : alloc_(other.get_allocator())
  {}
  template <typename U, typename UAllocator>
  allocator_delete(allocator_delete<U, UAllocator>&& other) noexcept
      : alloc_(::cuda::std::move(other.get_allocator()))
  {}

  template <typename U, typename UAllocator>
  allocator_delete& operator=(allocator_delete<U, UAllocator> const& other) noexcept
  {
    alloc_ = other.get_allocator();
    return *this;
  }
  template <typename U, typename UAllocator>
  allocator_delete& operator=(allocator_delete<U, UAllocator>&& other) noexcept
  {
    alloc_ = ::cuda::std::move(other.get_allocator());
    return *this;
  }

  void operator()(pointer p)
  {
    using traits = detail::allocator_traits<::cuda::std::remove_cvref_t<Allocator>>;
    typename traits::allocator_type alloc_T(alloc_);

    if (nullptr != detail::pointer_traits<pointer>::get(p))
    {
      if constexpr (!Uninitialized)
      {
        traits::destroy(alloc_T, thrust::raw_pointer_cast(p));
      }
      traits::deallocate(alloc_T, p, 1);
    }
  }

  allocator_type& get_allocator() noexcept
  {
    return alloc_;
  }
  allocator_type const& get_allocator() const noexcept
  {
    return alloc_;
  }

  void swap(allocator_delete& other) noexcept
  {
    using ::cuda::std::swap;
    swap(alloc_, other.alloc_);
  }

private:
  allocator_type alloc_;
};

//! Deleter that returns the storage of a single \p T to \p Allocator without running \p T's destructor.
template <class T, class Allocator>
using uninitialized_allocator_delete = allocator_delete<T, Allocator, /* Uninitialized = */ true>;

//! Deleter for \p std::unique_ptr<T[]> that destroys and deallocates an array of objects through an allocator.
//!
//! \tparam T             Element type of the owned array.
//! \tparam Allocator     Allocator type (cv/ref qualifiers are stripped). It is rebound to \p T, so it does not
//!                       have to allocate \p T already.
//! \tparam Uninitialized When \c true, only the storage is deallocated; the elements are not destroyed.
template <typename T, typename Allocator, bool Uninitialized = false>
struct array_allocator_delete final
{
  // Let every specialization read another's private element count. The converting constructors and
  // assignments below access `other.count_` on a different specialization, which would otherwise be
  // an access error as soon as they are instantiated.
  template <typename, typename, bool>
  friend struct array_allocator_delete;

  //! \p Allocator without cv/ref qualifiers, rebound to \p T.
  using allocator_type =
    typename std::remove_cv<typename std::remove_reference<Allocator>::type>::type::template rebind<T>::other;
  //! Pointer type of \p allocator_type; \p std::unique_ptr stores this instead of \c T*.
  using pointer = typename detail::allocator_traits<allocator_type>::pointer;

  //! Constructs the deleter from an allocator and the number \p n of elements in the managed array.
  template <typename UAllocator>
  array_allocator_delete(UAllocator&& other, std::size_t n) noexcept
      : alloc_(THRUST_FWD(other))
      , count_(n)
  {}

  //! Converting copy constructor: copies \p other's allocator (as an allocator for \p T) and element count.
  template <typename U, typename UAllocator>
  array_allocator_delete(array_allocator_delete<U, UAllocator> const& other) noexcept
      : alloc_(other.get_allocator())
      , count_(other.count_)
  {}
  //! Converting move constructor: moves \p other's allocator (as an allocator for \p T) and copies its count.
  template <typename U, typename UAllocator>
  array_allocator_delete(array_allocator_delete<U, UAllocator>&& other) noexcept
      : alloc_(::cuda::std::move(other.get_allocator()))
      , count_(other.count_)
  {}

  //! Converting copy assignment: takes a copy of \p other's allocator and its element count.
  template <typename U, typename UAllocator>
  array_allocator_delete& operator=(array_allocator_delete<U, UAllocator> const& other) noexcept
  {
    alloc_ = other.get_allocator();
    count_ = other.count_;
    return *this;
  }
  //! Converting move assignment: moves \p other's allocator and copies its element count.
  template <typename U, typename UAllocator>
  array_allocator_delete& operator=(array_allocator_delete<U, UAllocator>&& other) noexcept
  {
    alloc_ = ::cuda::std::move(other.get_allocator());
    count_ = other.count_;
    return *this;
  }

  //! Destroys the \p count_ elements (unless \p Uninitialized) and deallocates the array at \p p.
  //! Does nothing when \p p is null.
  void operator()(pointer p)
  {
    // Use the traits of the rebound allocator_type so that deallocate sees T's pointer type even when
    // Allocator was not already rebound to T.
    using traits = detail::allocator_traits<allocator_type>;

    if (nullptr != detail::pointer_traits<pointer>::get(p))
    {
      if constexpr (!Uninitialized)
      {
        destroy_n(alloc_, p, count_);
      }
      traits::deallocate(alloc_, p, count_);
    }
  }

  //! Returns the stored (rebound) allocator.
  allocator_type& get_allocator() noexcept
  {
    return alloc_;
  }
  //! Returns the stored (rebound) allocator.
  allocator_type const& get_allocator() const noexcept
  {
    return alloc_;
  }

  //! Exchanges allocators and element counts of \c *this and \p other.
  void swap(array_allocator_delete& other) noexcept
  {
    using ::cuda::std::swap;
    swap(alloc_, other.alloc_);
    swap(count_, other.count_);
  }

private:
  allocator_type alloc_;
  std::size_t count_;
};

//! Deleter that returns the storage of an array of \p T to \p Allocator without destroying the elements.
template <class T, class Allocator>
using uninitialized_array_allocator_delete = array_allocator_delete<T, Allocator, /* Uninitialized = */ true>;

///////////////////////////////////////////////////////////////////////////////

//! A callable deleter (inheriting its call operator from \p Lambda) that also exposes \p Pointer as its
//! \c pointer member type, which \p std::unique_ptr then uses as its stored pointer type.
template <typename Pointer, typename Lambda>
struct tagged_deleter : Lambda
{
  //! Pointer type advertised to \p std::unique_ptr.
  using pointer = Pointer;

  //! Initializes the \p Lambda base subobject from \p fn.
  _CCCL_HOST_DEVICE tagged_deleter(Lambda&& fn)
      : Lambda(THRUST_FWD(fn))
  {}
};

//! Wraps the callable \p l in a \p tagged_deleter whose \c pointer member type is \p Pointer, so that it can serve
//! as the deleter of a \p std::unique_ptr holding a (possibly fancy) \p Pointer.
//!
//! The callable type is decayed, so both rvalue and lvalue callables are accepted: rvalues are moved, lvalues are
//! copied. Without decaying, an lvalue argument deduces \p Lambda as a reference type, and
//! \p tagged_deleter<Pointer, Lambda> would try to derive from a reference, which is ill-formed.
template <typename Pointer, typename Lambda>
_CCCL_HOST_DEVICE tagged_deleter<Pointer, ::cuda::std::decay_t<Lambda>> make_tagged_deleter(Lambda&& l)
{
  using callable_t = ::cuda::std::decay_t<Lambda>;

  // tagged_deleter's constructor takes callable_t&&, so first materialize an owned callable.
  callable_t owned(THRUST_FWD(l));
  return tagged_deleter<Pointer, callable_t>(::cuda::std::move(owned));
}

///////////////////////////////////////////////////////////////////////////////

//! Allocates one \p T through a copy of \p alloc rebound to \p T, constructs it from \p args, and returns it in a
//! \p std::unique_ptr whose deleter destroys and deallocates through that allocator.
//! If construction throws, the storage is deallocated before the exception propagates.
template <typename T, typename Allocator, typename... Args>
_CCCL_HOST
std::unique_ptr<T,
                allocator_delete<T,
                                 typename detail::allocator_traits<
                                   ::cuda::std::remove_cvref_t<Allocator>>::template rebind_traits<T>::allocator_type>>
allocate_unique(Allocator const& alloc, Args&&... args)
{
  using traits         = typename detail::allocator_traits<::cuda::std::remove_cvref_t<Allocator>>::template rebind_traits<T>;
  using alloc_t        = typename traits::allocator_type;
  using ptr_t          = typename traits::pointer;
  using result_deleter = allocator_delete<T, alloc_t>;

  alloc_t rebound(alloc);

  // Until construction succeeds, the storage is owned by a guard that only deallocates (no object exists yet).
  auto free_storage = make_tagged_deleter<ptr_t>([&rebound](ptr_t storage) {
    traits::deallocate(rebound, storage, 1);
  });
  std::unique_ptr<T, decltype(free_storage)> guard(traits::allocate(rebound, 1), free_storage);

  traits::construct(rebound, thrust::raw_pointer_cast(guard.get()), THRUST_FWD(args)...);

  // The object now exists: hand the storage to a deleter that destroys and deallocates.
  result_deleter owner_deleter(alloc);
  return std::unique_ptr<T, result_deleter>(guard.release(), ::cuda::std::move(owner_deleter));
}

//! Allocates (without constructing) storage for one \p T through a copy of \p alloc rebound to \p T, and returns it
//! in a \p std::unique_ptr whose deleter deallocates without running \p T's destructor.
template <typename T, typename Allocator>
_CCCL_HOST std::unique_ptr<
  T,
  uninitialized_allocator_delete<
    T,
    typename detail::allocator_traits<::cuda::std::remove_cvref_t<Allocator>>::template rebind_traits<T>::allocator_type>>
uninitialized_allocate_unique(Allocator const& alloc)
{
  using traits         = typename detail::allocator_traits<::cuda::std::remove_cvref_t<Allocator>>::template rebind_traits<T>;
  using alloc_t        = typename traits::allocator_type;
  using ptr_t          = typename traits::pointer;
  using result_deleter = uninitialized_allocator_delete<T, alloc_t>;

  alloc_t rebound(alloc);

  // Temporary owner of the raw storage until it is handed over to the returned pointer.
  auto free_storage = make_tagged_deleter<ptr_t>([&rebound](ptr_t storage) {
    traits::deallocate(rebound, storage, 1);
  });
  std::unique_ptr<T, decltype(free_storage)> guard(traits::allocate(rebound, 1), free_storage);

  result_deleter owner_deleter(rebound);
  return std::unique_ptr<T, result_deleter>(guard.release(), ::cuda::std::move(owner_deleter));
}

//! Allocates \p n objects of type \p T through a copy of \p alloc rebound to \p T, constructs each of them from
//! \p args, and returns them in a \p std::unique_ptr<T[]> whose deleter destroys and deallocates all \p n elements.
//! If element construction throws, the storage is deallocated before the exception propagates.
template <typename T, typename Allocator, typename Size, typename... Args>
_CCCL_HOST std::unique_ptr<
  T[],
  array_allocator_delete<T,
                         typename detail::allocator_traits<typename std::remove_cv<typename std::remove_reference<
                           Allocator>::type>::type>::template rebind_traits<T>::allocator_type>>
allocate_unique_n(Allocator const& alloc, Size n, Args&&... args)
{
  using traits         = typename detail::allocator_traits<::cuda::std::remove_cvref_t<Allocator>>::template rebind_traits<T>;
  using alloc_t        = typename traits::allocator_type;
  using ptr_t          = typename traits::pointer;
  using result_deleter = array_allocator_delete<T, alloc_t>;

  alloc_t rebound(alloc);

  // Guard that only frees the raw storage; active until every element has been constructed.
  auto free_storage = make_tagged_deleter<ptr_t>([&rebound, n](ptr_t storage) {
    traits::deallocate(rebound, storage, n);
  });
  std::unique_ptr<T[], decltype(free_storage)> guard(traits::allocate(rebound, n), free_storage);

  uninitialized_construct_n_with_allocator(rebound, guard.get(), n, THRUST_FWD(args)...);

  // All elements exist: transfer the storage to a deleter that destroys and deallocates them.
  result_deleter owner_deleter(rebound, n);
  return std::unique_ptr<T[], result_deleter>(guard.release(), ::cuda::std::move(owner_deleter));
}

//! Allocates (without constructing) storage for \p n objects of type \p T through a copy of \p alloc rebound to
//! \p T, and returns it in a \p std::unique_ptr<T[]> whose deleter deallocates without destroying any elements.
template <typename T, typename Allocator, typename Size>
_CCCL_HOST std::unique_ptr<
  T[],
  uninitialized_array_allocator_delete<
    T,
    typename detail::allocator_traits<::cuda::std::remove_cvref_t<Allocator>>::template rebind_traits<T>::allocator_type>>
uninitialized_allocate_unique_n(Allocator const& alloc, Size n)
{
  using traits         = typename detail::allocator_traits<::cuda::std::remove_cvref_t<Allocator>>::template rebind_traits<T>;
  using alloc_t        = typename traits::allocator_type;
  using ptr_t          = typename traits::pointer;
  using result_deleter = uninitialized_array_allocator_delete<T, alloc_t>;

  alloc_t rebound(alloc);

  // Temporary owner of the raw storage until it is handed over to the returned pointer.
  auto free_storage = make_tagged_deleter<ptr_t>([&rebound, n](ptr_t storage) {
    traits::deallocate(rebound, storage, n);
  });
  std::unique_ptr<T[], decltype(free_storage)> guard(traits::allocate(rebound, n), free_storage);

  result_deleter owner_deleter(rebound, n);
  return std::unique_ptr<T[], result_deleter>(guard.release(), ::cuda::std::move(owner_deleter));
}

///////////////////////////////////////////////////////////////////////////////

THRUST_NAMESPACE_END
